import os
import sys
import os.path as osp
import torch
from ..smp import *


def get_gpu_num(model_name):
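    # Heuristic: map the parameter-size tag in the model name (e.g. '70b', '13b')
    # to a GPU count; unrecognized names conservatively default to 8 GPUs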
    model_name = model_name.lower()
    kws = {
        8: ["65b", "70b"],
        4: ["30b", "33b", "35b", "40b"],
        2: ["13b", "14b", "20b"],
        1: ["6b", "7b", "moss"],
    }
    for k in [8, 4, 2, 1]:
        for keyword in kws[k]:
            if keyword in model_name:
                return k
    return 8


validated_llms = [
    "internlm/internlm-chat-7b",
    "internlm/internlm-chat-7b-8k",
    "internlm/internlm-chat-20b",
    "Qwen/Qwen-7B-Chat",
    "Qwen/Qwen-14B-Chat",
    "THUDM/chatglm2-6b",
    "THUDM/chatglm2-6b-32k",
    "THUDM/chatglm3-6b",
    "THUDM/chatglm3-6b-32k",
    "baichuan-inc/Baichuan2-7B-Chat",
    "baichuan-inc/Baichuan2-13B-Chat",
    "lmsys/vicuna-7b-v1.5",
    "lmsys/vicuna-13b-v1.5",
    "meta-llama/Llama-2-7b-chat-hf",
]
Auto_model = ["chatglm"]


class HFChatModel:
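    """Wrapper around HuggingFace chat LLMs that unifies model loading, context-length
    handling, and single-/multi-turn chat generation across several model families."""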

    def _get_context_length(self, model, model_path):
        # By default, we use model.config.seq_length
        model_path = model_path.lower()
        if "baichuan" in model_path:
            context_window = model.config.model_max_length
        elif "internlm" in model_path or "llama" in model_path:
            context_window = model.config.max_position_embeddings
        elif "vicuna" in model_path:
            context_window = model.generation_config.max_length
        else:
            # chatglm & qwen
            context_window = model.config.seq_length
        return context_window

    def _get_context_length_robust(self, model, model_path):
        try:
            context_window = self._get_context_length(model, model_path)
            return context_window
        except Exception as err:
            self.logger.critical(f"{type(err)}: {err}")
            self.logger.critical(
                "Failed to extract context_window information from config / generation_config. "
                "Please read the above code and check if the logic works for your model path. "
            )
            raise NotImplementedError

    def __init__(self, model_path, system_prompt: str = None, **kwargs):

        self.logger = get_logger("HFChatModel")
        if "vicuna" in model_path.lower():
            try:
                from fastchat.model import get_conversation_template
            except Exception as err:
                self.logger.critical("Please install fastchat first to use vicuna. ")
                raise err

        self.explicit_device = kwargs.pop("device", None)

        if self.explicit_device is None:
            # If CUDA_VISIBLE_DEVICES is unset or left at the full default, restrict it
            # to the GPUs this model needs (optionally shifted by `gpu_offset`)
            if (
                "CUDA_VISIBLE_DEVICES" not in os.environ
                or os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3,4,5,6,7"
            ):
                num_gpu = get_gpu_num(model_path)
                gpu_offset = kwargs.pop("gpu_offset", 0)
                cuda_visible_devices = ",".join(
                    [str(i) for i in range(gpu_offset, gpu_offset + num_gpu)]
                )
                os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
        from transformers.generation import GenerationConfig

        if model_path not in validated_llms:
            self.logger.warning(
                f"{model_path} is not in the list of validated LLMs; inference may fail or misbehave. "
            )

        self.model_path = model_path
        if listinstr(Auto_model, model_path):
            LoadModel = AutoModel
        else:
            LoadModel = AutoModelForCausalLM

        assert osp.exists(model_path) or len(model_path.split("/")) == 2, (
            "model_path should be a local checkpoint directory or an 'org/name' HuggingFace repo id"
        )

        device = self.explicit_device if self.explicit_device else "auto"

        # Explicit dtype overrides used for InternLM chat checkpoints; other models
        # load with their default precision
        precision = {}
        if "internlm-chat-7b" in model_path:
            precision = {"torch_dtype": torch.float16}
        elif "internlm-chat-20b" in model_path:
            precision = {"torch_dtype": torch.bfloat16}

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        model = LoadModel.from_pretrained(
            model_path, trust_remote_code=True, device_map="cpu", **precision
        )
        model = model.eval()

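        # The model is loaded on CPU first, then moved to the requested CUDA device
        # (or the default GPU when no explicit device index is given)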
        if device != "cpu":
            model = model.to(f"cuda:{device}" if isinstance(device, int) else "cuda")
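        # Prefer a repo-provided GenerationConfig; keep the model default if loading fails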
        try:
            model.generation_config = GenerationConfig.from_pretrained(
                model_path, trust_remote_code=True, device_map=device
            )
        except Exception as err:
            self.logger.warning(f"{type(err)}: {err}")

        torch.cuda.empty_cache()
        self.model = model
        self.context_length = self._get_context_length_robust(
            model=model, model_path=model_path
        )
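        # Token budget reserved for the model's reply when checking prompt length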
        self.answer_buffer = 192
        self.system_prompt = system_prompt
        for k, v in kwargs.items():
            self.logger.info(
                f"Following kwargs will be used for generation (unless overridden per call): {k}: {v}"
            )
        self.kwargs = kwargs

    def generate_str(self, input, **kwargs):
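        """Single-turn generation: route `input` to the model-specific chat / generate
        interface and return the response string."""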
        if "baichuan" in self.model_path.lower():
            messages = []
            messages.append({"role": "user", "content": input})
            resp = self.model.chat(self.tokenizer, messages, **kwargs)
        elif "vicuna" in self.model_path.lower():
            from fastchat.model import get_conversation_template

            conv = get_conversation_template("vicuna")
            conv.append_message(conv.roles[0], input)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            inputs = self.tokenizer([prompt], return_tensors="pt")
            if torch.cuda.is_available():
                for k in inputs:
                    inputs[k] = inputs[k].cuda()

            params = dict(
                do_sample=True,
                temperature=0.7,
                repetition_penalty=1.0,
                max_new_tokens=512,
            )
            params.update(self.kwargs)
            params.update(kwargs)
            outputs = self.model.generate(**inputs, **params)
            resp = self.tokenizer.decode(
                outputs[0][len(inputs["input_ids"][0]) :],
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
            )

        else:
            # Copy to avoid mutating the shared default kwargs
            params = dict(self.kwargs)
            params.update(kwargs)
            resp, _ = self.model.chat(self.tokenizer, input, history=[], **params)

        return resp

    def length_ok(self, inputs):
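        """Check whether the system prompt, all messages, and the reply buffer fit
        within the model's context window."""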
        tot = (
            len(self.tokenizer.encode(self.system_prompt))
            if self.system_prompt is not None
            else 0
        )
        for s in inputs:
            tot += len(self.tokenizer.encode(s))
        return tot + self.answer_buffer < self.context_length

    def generate_list(self, full_inputs, offset=0, **kwargs):
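        """Multi-turn generation. `full_inputs` alternates user / assistant messages and
        ends with a user message; if the prompt does not fit the context window, the
        earliest messages are dropped (tracked via `offset`). Returns (response, offset)."""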
        assert isinstance(full_inputs, list)

        inputs = full_inputs[offset:]
        if not self.length_ok(inputs):
            return self.generate_list(full_inputs, offset + 1, **kwargs)

        model_path = self.model_path.lower()

        if "baichuan" in model_path:
            input_msgs = []
            if self.system_prompt is not None:
                input_msgs.append(dict(role="user", content=self.system_prompt))
            if len(inputs):
                assert isinstance(inputs, list) and isinstance(inputs[0], str)
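                # An odd number of messages means the first turn is from the user;
                # an even number means the history starts with an assistant turn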
                roles = (
                    ["user", "assistant"]
                    if len(inputs) % 2 == 1
                    else ["assistant", "user"]
                )
                roles = roles * len(inputs)
                for role, msg in zip(roles, inputs):
                    input_msgs.append(dict(role=role, content=msg))
            response = self.model.chat(self.tokenizer, input_msgs)
        elif "vicuna" in model_path:
            from fastchat.model import get_conversation_template

            conv = get_conversation_template("vicuna")
            assert isinstance(inputs, list) and isinstance(inputs[0], str)
            if len(inputs) % 2 == 1:
                if self.system_prompt is not None:
                    conv.append_message(conv.roles[0], self.system_prompt)
                for i in range(len(inputs) // 2):
                    conv.append_message(conv.roles[0], inputs[2 * i])
                    conv.append_message(conv.roles[1], inputs[2 * i + 1])
            else:
                assert self.system_prompt is not None
                conv.append_message(conv.roles[0], self.system_prompt)
                conv.append_message(conv.roles[1], inputs[0])
                for i in range(len(inputs) // 2 - 1):
                    conv.append_message(conv.roles[0], inputs[2 * i + 1])
                    conv.append_message(conv.roles[1], inputs[2 * i + 2])
            conv.append_message(conv.roles[0], inputs[-1])
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            inputs = self.tokenizer([prompt], return_tensors="pt")
            if torch.cuda.is_available():
                for k in inputs:
                    inputs[k] = inputs[k].cuda()

            params = dict(
                do_sample=True,
                temperature=0.7,
                repetition_penalty=1.0,
                max_new_tokens=512,
            )
            params.update(self.kwargs)
            params.update(kwargs)

            outputs = self.model.generate(**inputs, **params)
            response = self.tokenizer.decode(
                outputs[0][len(inputs["input_ids"][0]) :],
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
            )
            response = response.lstrip("\n")
        else:
            # The default option, support internlm, chatglm, qwen
            history, msg = [], None
            if len(inputs) % 2 == 1:
                if self.system_prompt is not None:
                    history = [(self.system_prompt, "")]
                for i in range(len(inputs) // 2):
                    history.append((inputs[2 * i], inputs[2 * i + 1]))
            else:
                assert self.system_prompt is not None
                history = [(self.system_prompt, inputs[0])]
                for i in range(len(inputs) // 2 - 1):
                    history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
            msg = inputs[-1]

            # Copy to avoid mutating the shared default kwargs
            params = dict(self.kwargs)
            params.update(kwargs)
            response, _ = self.model.chat(
                self.tokenizer, msg, history=history, **params
            )

        return response, offset

    def generate(self, inputs, **kwargs):
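        """Unified entry point: a string is treated as a single-turn prompt, a list as
        a multi-turn message history."""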
        if isinstance(inputs, str):
            return self.generate_str(inputs, **kwargs)
        elif isinstance(inputs, list):
            return self.generate_list(inputs, **kwargs)
        else:
            raise ValueError(f"Unsupported input type for generate: {type(inputs)}")
